import numpy as np
import torch
import os
from icp import ICP


class Agents:
    """Thin wrapper that reshapes batched multi-agent inputs for an ``ICP`` policy.

    The wrapper flattens the ``(vec_env, n_agents, ...)`` leading axes into a
    single ``vec_env * n_agents`` batch axis before calling the policy
    networks, and restores the two axes on the way out.

    Attributes read from ``args`` in this class: ``vec_env``, ``n_agents``,
    ``rnn_layers``, ``hidden_size``, ``message_shape`` and ``cuda``.
    """

    def __init__(self, args):
        """Store the configuration and build the underlying ``ICP`` policy."""
        self.args = args
        self.policy = ICP(args)

    def _rotate_messages(self, message):
        """Build, for every agent, a rotated view of all messages.

        For index ``i`` along axis 1 the result stacks
        ``message[:, i], message[:, i + 1], ...`` (indices wrap modulo
        ``n_agents``). The first entry -- index ``i`` itself -- is detached, so
        no gradient flows back through it; the remaining entries keep their
        autograd history.

        Args:
            message: torch tensor indexed as ``message[:, agent]``.
                NOTE(review): presumably (vec_env, n_agents, message_shape),
                matching ``init_message`` -- confirm against callers.

        Returns:
            Tensor reshaped to ``(vec_env * n_agents, -1)``.
        """
        n_agents = self.args.n_agents
        per_agent = []
        for i in range(n_agents):
            # Slot 0 is the entry at index i itself, cut off from the graph.
            slots = [message[:, i].detach()]
            # Remaining slots walk around the ring of indices, gradients kept.
            slots.extend(message[:, (i + j) % n_agents] for j in range(1, n_agents))
            per_agent.append(torch.stack(slots, dim=1))
        return torch.stack(per_agent, dim=1).reshape(self.args.vec_env * n_agents, -1)

    def _prepare_inputs(self, s, message, available_actions, hidden):
        """Flatten the env/agent axes of all network inputs.

        Shared by ``select_actions`` and ``compute_target`` so both paths feed
        identically shaped tensors to their respective networks.

        Returns:
            Tuple ``(o, m, a_u, h)`` where ``o``, ``m`` and ``a_u`` have a
            leading axis of ``vec_env * n_agents`` and ``h`` is
            ``(rnn_layers, vec_env * n_agents, -1)``.
        """
        flat = self.args.vec_env * self.args.n_agents

        o = torch.from_numpy(s).float().reshape(flat, -1)
        m = self._rotate_messages(message)
        a_u = torch.from_numpy(available_actions).float().reshape(flat, -1)
        h = hidden.reshape(self.args.rnn_layers, flat, -1)

        # Only the tensors created from numpy here are moved; `message` and
        # `hidden` are left on whatever device the caller supplied them on.
        if self.args.cuda:
            o = o.cuda()
            a_u = a_u.cuda()

        return o, m, a_u, h

    def select_actions(self, s, message, available_actions, hidden, eps):
        """Run the online network ``policy.ICN`` and un-flatten its outputs.

        Args:
            s: numpy array of observations; reshaped to
                ``(vec_env * n_agents, -1)``.
            message: torch tensor of messages (see ``_rotate_messages``).
            available_actions: numpy array reshaped to
                ``(vec_env * n_agents, -1)``.
            hidden: torch tensor reshaped to
                ``(rnn_layers, vec_env * n_agents, -1)``.
            eps: forwarded to ``policy.ICN`` as the ``eps`` keyword.

        Returns:
            Tuple ``(us, q_value, hs, detached_ms, ms)``:
            ``us`` -- actions as a numpy array ``(vec_env, n_agents, -1)``;
            ``q_value`` -- tensor ``(vec_env, n_agents, -1)``;
            ``hs`` -- tensor ``(rnn_layers, vec_env, n_agents, -1)``;
            ``detached_ms`` -- outgoing messages as a detached numpy array
            ``(vec_env, n_agents, -1)``;
            ``ms`` -- the same messages as a tensor that keeps its graph.
        """
        o, m, a_u, h = self._prepare_inputs(s, message, available_actions, hidden)

        action, q_value, m_out, h_out = self.policy.ICN(o, m, a_u, h, eps=eps)

        vec_env, n_agents = self.args.vec_env, self.args.n_agents
        us = action.detach().cpu().reshape(vec_env, n_agents, -1).numpy()
        q_value = q_value.reshape(vec_env, n_agents, -1)
        hs = h_out.reshape(self.args.rnn_layers, vec_env, n_agents, -1)
        detached_ms = m_out.detach().cpu().reshape(vec_env, n_agents, -1).numpy()
        ms = m_out.reshape(vec_env, n_agents, -1)

        return us, q_value, hs, detached_ms, ms

    def compute_target(self, s, message, available_actions, hidden):
        """Run ``policy.ICN_target`` and un-flatten its outputs.

        Arguments are prepared exactly as in ``select_actions``; the first
        value returned by the network is discarded.

        Returns:
            Tuple ``(q_value, hs, ms)`` with shapes
            ``(vec_env, n_agents, -1)``,
            ``(rnn_layers, vec_env, n_agents, -1)`` and
            ``(vec_env, n_agents, -1)`` respectively.
        """
        o, m, a_u, h = self._prepare_inputs(s, message, available_actions, hidden)

        _, q_value, m_out, h_out = self.policy.ICN_target(o, m, a_u, h)

        vec_env, n_agents = self.args.vec_env, self.args.n_agents
        q_value = q_value.reshape(vec_env, n_agents, -1)
        hs = h_out.reshape(self.args.rnn_layers, vec_env, n_agents, -1)
        ms = m_out.reshape(vec_env, n_agents, -1)

        return q_value, hs, ms

    def learn(self, transitions):
        """Compute ``policy.loss(transitions)``, pass it to ``policy.trained_step``, and return it."""
        loss = self.policy.loss(transitions)

        self.policy.trained_step(loss)
        return loss

    def init_hidden(self):
        """Return an all-zero float32 hidden state.

        Shape is ``(rnn_layers, vec_env, n_agents, hidden_size)``; the tensor
        is moved to CUDA when ``args.cuda`` is truthy.
        """
        h = torch.zeros(
            self.args.rnn_layers,
            self.args.vec_env,
            self.args.n_agents,
            self.args.hidden_size,
            dtype=torch.float32,
        )
        return h.cuda() if self.args.cuda else h

    def init_message(self):
        """Return the initial float32 message tensor.

        Shape is ``(vec_env, n_agents, message_shape)``; every entry is zero
        except the last component of each message, which is set to 1. The
        tensor is moved to CUDA when ``args.cuda`` is truthy.
        """
        m = torch.zeros(
            self.args.vec_env,
            self.args.n_agents,
            self.args.message_shape,
            dtype=torch.float32,
        )
        if self.args.cuda:
            m = m.cuda()
        # NOTE(review): the meaning of the trailing 1 is not visible in this
        # file -- presumably a "no message" marker; confirm against ICP.
        m[:, :, -1] = 1
        return m